import pygame
import random
import time
from pygame.locals import *
from characters import *
from custom_events import *
from boss import BossB2
from spritesheet import SpriteSheet


class Scene(pygame.sprite.LayeredUpdates):
    """A game scene holding the background, clouds, fade overlay, the
    player's jet and the B2 boss, plus the sprite groups used for play.

    Sprites are added to this layered group with explicit layers:
    0 = background, 1 = jet and boss, 10 = clouds, 20 = fade overlay.
    """

    def __init__(self, screen, custom_events=None):
        """Build the scene's sprites and load explosion animations.

        Args:
            screen: surface stored as ``self.screen`` (presumably the display
                surface the sprites draw on — TODO confirm).
            custom_events: object exposing the custom event-type ids used by
                ``handle_event`` (``add_cloud_event``, ``b2_*_pod_*_event``).
                May be ``None``, in which case custom events are ignored.
        """
        super().__init__()
        self.runner = None
        self.bonus = pygame.sprite.Group()
        self.enemies = pygame.sprite.Group()
        self.enemy_bullets = pygame.sprite.Group()
        self.jet_bullets = pygame.sprite.Group()
        self.custom_events = custom_events
        self.screen = screen
        self.background = GameBackgroud(self,
                                        pygame.image.load("images/bg_river.jpg")
                                        )

        # The scene's bounds are taken from the background image.
        self.rect = self.background.rect

        self.clouds = Clouds(self, (0, -300), (-4, 20), 1)
        self.add(self.background, layer=0)
        self.add(self.clouds, layer=10)

        # Topmost layer; key_pressed() activates it on ESC and watches `done`.
        self.fadelayer = FadeEffect(self)
        self.add(self.fadelayer, layer=20)

        self.jet = Fighter(self)
        self.add(self.jet, layer=1)

        # The boss is both drawn (layer 1) and tracked as an enemy.
        self.b2 = BossB2(self)
        self.add(self.b2, layer=1)
        self.enemies.add(self.b2)

        self.b2.set_target(self.jet)

        self.load_explosion_images()

    def load_explosion_images(self):
        """Populate ``self.explosion_anim`` with explosion frame lists.

        ``self.explosion_anim[size][variant]`` is a list of surfaces, for
        sizes 'big', 'small', 'medium' and 'huge' and variants 0-3.
        """
        self.explosion_anim = {
            size: [[], [], [], []]
            for size in ('big', 'small', 'medium', 'huge')
        }

        # 'small': 8 single-image frames scaled to 48x48. The four variants are
        # the same frames unflipped / flipped horizontally / vertically / both.
        small = self.explosion_anim['small']
        for i in range(8):
            filename = 'images/explo/mini_explo_{}.png'.format(i)
            img = pygame.image.load(filename).convert_alpha()
            img_sm = pygame.transform.scale(img, (48, 48))
            small[0].append(img_sm)
            small[1].append(pygame.transform.flip(img_sm, True, False))
            small[2].append(pygame.transform.flip(img_sm, False, True))
            small[3].append(pygame.transform.flip(img_sm, True, True))

        # 'medium': one sprite sheet per variant, each sliced with its own grid
        # arguments for SpriteSheet.load_grid_images.
        medium_sheets = (
            ('images/explo/e2.png', (5, 5)),
            ('images/explo/e3.png', (5, 5)),
            ('images/explo/e5.png', (4, 4)),
            ('images/explo/e8.png', (5, 5)),
        )
        for variant, (path, grid) in enumerate(medium_sheets):
            self.explosion_anim['medium'][variant] = (
                SpriteSheet(path).load_grid_images(*grid))

        # 'big' and 'huge' reuse the medium frames: scaled up, and flipped
        # (vertically for 'big', both axes for 'huge') for visual variety.
        for variant, frames in enumerate(self.explosion_anim['medium']):
            for frame in frames:
                self.explosion_anim['big'][variant].append(pygame.transform.flip(
                    pygame.transform.scale(frame, (96, 96)), False, True))
                self.explosion_anim['huge'][variant].append(pygame.transform.flip(
                    pygame.transform.scale(frame, (64 * 3, 64 * 3)), True, True))

    def set_runner(self, runner):
        """Attach the runner whose ``running`` flag this scene may clear."""
        self.runner = runner

    def set_rect(self, rect):
        """Override the scene bounds (initially the background's rect)."""
        self.rect = rect

    def play_music(self, path):
        """Load the music file at ``path`` and loop it forever at volume 0.1."""
        pygame.mixer.music.load(path)
        pygame.mixer.music.set_volume(0.1)
        pygame.mixer.music.play(-1)

    def load_sounds(self):
        """Load the scene's sound effects into ``sound_*`` attributes.

        Not called by ``__init__``; presumably invoked by the owner once the
        mixer is ready — TODO confirm.
        """
        self.sound_hit = pygame.mixer.Sound("sound/hit.wav")
        self.sound_gun_shots = pygame.mixer.Sound("sound/gun.wav")
        self.sound_explo = pygame.mixer.Sound("sound/DeathFlash.wav")
        self.sound_explo_remote = pygame.mixer.Sound("sound/distant_explo.wav")
        self.sound_explo_air = pygame.mixer.Sound("sound/hjm_big_explosion_3.wav")

    def music_fade_out(self):
        """Fade out the currently playing music over 500 ms."""
        pygame.mixer.music.fadeout(500)

    def draw_elements(self):
        """Ask every sprite to draw itself, in the group's layer order.

        Each sprite's own ``draw()`` is used rather than ``Group.draw``.
        """
        for sprite in self.sprites():
            sprite.draw()

    def update_elements(self):
        """Advance every sprite by one frame via ``Group.update``."""
        self.update()

    def _custom_event_handlers(self):
        """Return a mapping of custom event type -> zero-arg callback.

        Returns an empty mapping when the scene was built without
        ``custom_events``, so such events are simply ignored.
        NOTE(review): assumes each custom event type id is distinct (e.g. from
        ``pygame.event.custom_type``); duplicates would collapse in the dict.
        """
        events = self.custom_events
        if events is None:
            return {}
        return {
            events.add_cloud_event: lambda: self.clouds.reset_position(),
            events.b2_main_pod_on_event:
                lambda: self.b2.set_main_pod_ready(True),
            events.b2_main_pod_off_event:
                lambda: self.b2.set_main_pod_ready(False),
            events.b2_left_pod_on_event:
                lambda: self.b2.set_left_wing_pod_ready(True),
            events.b2_left_pod_off_event:
                lambda: self.b2.set_left_wing_pod_ready(False),
            events.b2_right_pod_on_event:
                lambda: self.b2.set_right_wing_pod_ready(True),
            events.b2_right_pod_off_event:
                lambda: self.b2.set_right_wing_pod_ready(False),
        }

    def handle_event(self):
        """Drain the pygame event queue, then process held keys.

        QUIT clears ``self.runner.running``; any KEYUP resets the jet to
        'shoot_ahead'; custom events reset the clouds or toggle the B2's weapon
        pods. Afterwards held keys are polled (``key_pressed``) and the jet's
        ``set_ahead()`` is called.
        """
        # Built once per call instead of re-reading every attribute per event.
        custom_handlers = self._custom_event_handlers()
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.runner.running = False

            if event.type == pygame.KEYUP:
                self.jet.current_state = 'shoot_ahead'

            handler = custom_handlers.get(event.type)
            if handler is not None:
                handler()

        self.key_pressed()
        self.jet.set_ahead()

    def key_pressed(self):
        """Poll the keyboard and apply held keys to the jet and fade layer.

        SPACE shoots; arrow keys move the jet and set its sprite state; ESC
        activates the fade layer. Once the fade layer reports ``done``, the
        runner is told to stop.
        """
        keys = pygame.key.get_pressed()
        if keys[K_SPACE]:
            self.jet.shoot()
        if keys[K_UP]:
            self.jet.current_state = 'shoot_ahead'
            self.jet.move_up()
        if keys[K_DOWN]:
            self.jet.current_state = 'shoot_ahead'
            self.jet.move_down()
        if keys[K_LEFT]:
            self.jet.current_state = 'left'
            self.jet.move_left()
        if keys[K_RIGHT]:
            self.jet.current_state = 'right'
            self.jet.move_right()
        # ESC quits the current episode: start the fade, stop when it is done.
        if keys[K_ESCAPE]:
            self.fadelayer.active = True

        if self.fadelayer.done:
            self.runner.running = False

    def detect_collision(self):
        """Placeholder for collision handling; currently does nothing."""
        pass
